package com.jwong.common.web.interceptor;

import com.jwong.common.context.PrimaryKeyContext;
import lombok.extern.slf4j.Slf4j;
import org.springframework.security.oauth2.common.OAuth2AccessToken;
import org.springframework.security.oauth2.provider.OAuth2Authentication;
import org.springframework.security.oauth2.provider.token.AccessTokenConverter;
import org.springframework.security.oauth2.provider.token.ResourceServerTokenServices;
import org.springframework.util.StringUtils;
import org.springframework.web.servlet.HandlerInterceptor;

import javax.annotation.Nonnull;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.util.Map;
import java.util.Optional;

/**
 * Intercepts the {@code Authorization} header, resolves the user primary key from the access token
 * and binds it to {@link PrimaryKeyContext} for the current request thread.
 *
 * <p>The key is read from the {@code user_primary_key} entry of the map produced by the configured
 * {@link AccessTokenConverter}. Binding is best-effort: a missing/empty header, an unreadable token
 * or a missing claim is logged as a warning and the request still proceeds ({@code preHandle}
 * always returns {@code true}). Exceptions thrown by the token services or converter (e.g. for an
 * invalid token) propagate to the caller unchanged.
 */
@Slf4j
public class AccessTokenHandlerInterceptor implements HandlerInterceptor {

    private static final String ACCESS_TOKEN_HEADER_NAME = "Authorization";
    private static final String ACCESS_TOKEN_TYPE = "Bearer";
    private static final String USER_PRIMARY_KEY = "user_primary_key";

    private final ResourceServerTokenServices resourceServerTokenServices;
    private final AccessTokenConverter accessTokenConverter;

    /**
     * @param resourceServerTokenServices used to read the access token and load its authentication
     * @param accessTokenConverter        used to convert the token and authentication into a claim map
     */
    public AccessTokenHandlerInterceptor(ResourceServerTokenServices resourceServerTokenServices,
                                         AccessTokenConverter accessTokenConverter) {
        this.resourceServerTokenServices = resourceServerTokenServices;
        this.accessTokenConverter = accessTokenConverter;
    }

    /**
     * Binds the user primary key carried by the request's access token to {@link PrimaryKeyContext},
     * unless a key is already bound on the current thread.
     *
     * @return always {@code true}; this interceptor never rejects a request
     * @throws Exception whatever the token services / converter throw; a non-numeric
     *                   {@code user_primary_key} claim surfaces as {@link NumberFormatException}
     */
    @Override
    public boolean preHandle(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response,
                             @Nonnull Object handler) throws Exception {

        // Something upstream already bound a key: keep it and do nothing.
        if (PrimaryKeyContext.get() != null) {
            return true;
        }

        String accessToken = extractAccessToken(request.getHeader(ACCESS_TOKEN_HEADER_NAME));
        if (!StringUtils.hasText(accessToken)) {
            log.warn("can not get accessToken, bind userPrimaryKey fail");
            return true;
        }

        OAuth2AccessToken token = resourceServerTokenServices.readAccessToken(accessToken);
        // readAccessToken may return null (e.g. token not present in the token store).
        if (token == null) {
            log.warn("can not read accessToken, bind userPrimaryKey fail");
            return true;
        }

        OAuth2Authentication authentication = resourceServerTokenServices.loadAuthentication(token.getValue());
        Map<String, ?> accessTokenMap = accessTokenConverter.convertAccessToken(token, authentication);
        Object rawPrimaryKey = accessTokenMap.get(USER_PRIMARY_KEY);
        if (rawPrimaryKey == null) {
            log.warn("accessToken has no {}, bind userPrimaryKey fail", USER_PRIMARY_KEY);
            return true;
        }

        // The claim may arrive as a Number or a String depending on the converter; toString covers both.
        long boundPrimaryKey = Long.parseLong(rawPrimaryKey.toString());
        PrimaryKeyContext.bind(boundPrimaryKey);
        if (log.isDebugEnabled()) {
            log.debug("bind userPrimaryKey: {} to PrimaryKeyContext", boundPrimaryKey);
        }
        return true;
    }

    /**
     * Unbinds the user primary key from {@link PrimaryKeyContext} once the request completes.
     *
     * <p>NOTE(review): this unbinds whatever key is present, including one that was already bound
     * before {@code preHandle} ran (and therefore not bound by this interceptor) — confirm intended.
     */
    @Override
    public void afterCompletion(@Nonnull HttpServletRequest request, @Nonnull HttpServletResponse response,
                                @Nonnull Object handler, Exception ex) throws Exception {

        if (PrimaryKeyContext.get() == null) {
            return;
        }
        Long unboundPrimaryKey = PrimaryKeyContext.unbind();
        if (log.isDebugEnabled()) {
            log.debug("unbind userPrimaryKey: {} from PrimaryKeyContext", unboundPrimaryKey);
        }
    }

    /**
     * Removes a leading, case-insensitive {@code Bearer} scheme from an {@code Authorization}
     * header value and trims surrounding whitespace.
     *
     * <p>Only the prefix is removed (by length), so occurrences of "Bearer" inside the token itself
     * are preserved.
     *
     * @param headerValue raw header value, may be {@code null}
     * @return the bare token (possibly empty), or {@code null} if {@code headerValue} is {@code null}
     */
    private static String extractAccessToken(String headerValue) {
        if (headerValue == null) {
            return null;
        }
        String value = headerValue.trim();
        if (StringUtils.startsWithIgnoreCase(value, ACCESS_TOKEN_TYPE)) {
            value = value.substring(ACCESS_TOKEN_TYPE.length()).trim();
        }
        return value;
    }

}
